{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gFaYLBeNPT4R"
   },
   "source": [
    "# tf-explain\n",
    "\n",
    "[![Pypi Version](https://img.shields.io/pypi/v/tf-explain.svg)](https://pypi.org/project/tf-explain/)\n",
    "[![Build Status](https://api.travis-ci.org/sicara/tf-explain.svg?branch=master)](https://travis-ci.org/sicara/tf-explain)\n",
    "[![Documentation Status](https://readthedocs.org/projects/tf-explain/badge/?version=latest)](https://tf-explain.readthedocs.io/en/latest/?badge=latest)\n",
    "![Python Versions](https://img.shields.io/badge/python-3.6%20|%203.7-%23EBBD68.svg)\n",
    "![Tensorflow Versions](https://img.shields.io/badge/tensorflow-2.0.0-blue.svg)\n",
    "[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)\n",
    "\n",
    "__tf-explain__ implements interpretability methods as TensorFlow 2.0 callbacks to __ease the understanding of neural networks__.  \n",
    "See [Introducing tf-explain, Interpretability for Tensorflow 2.0](https://blog.sicara.com/tf-explain-interpretability-tensorflow-2-9438b5846e35)\n",
    "\n",
    "__Documentation__: https://tf-explain.readthedocs.io"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "U3Yrxc7tP4Z-"
   },
   "source": [
    "## tf-explain example on MNIST / Fashion-MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "h0OLyf2pMWm8",
    "outputId": "3b556b1d-e044-41e9-e887-39509220b4d1"
   },
   "outputs": [],
   "source": [
    "try:\n",
    "  # %tensorflow_version only exists in Colab.\n",
    "  %tensorflow_version 2.x\n",
    "except Exception:\n",
    "  pass\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "colab_type": "code",
    "id": "lYAkOjobL1Xx",
    "outputId": "8b482029-42a0-4e1e-cb14-b498a4858918"
   },
   "outputs": [],
   "source": [
    "# Install the released tf-explain package from PyPI into the kernel's environment.\n",
    "# (An editable install of '.' fails here: the notebook directory has no setup.py.)\n",
    "%pip install tf-explain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EozpISZ1L8vd"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tf_explain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "k1fnyKktNIs4"
   },
   "outputs": [],
   "source": [
    "# Datasets selectable by name; the chosen entry is later used through its load_data()\n",
    "AVAILABLE_DATASETS = dict(\n",
    "    mnist=tf.keras.datasets.mnist,\n",
    "    fashion_mnist=tf.keras.datasets.fashion_mnist,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "colab_type": "code",
    "id": "HDHeVwweM7EC",
    "outputId": "7076b8d0-7657-4156-80da-0b78c1f2b970"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mnist\n"
     ]
    }
   ],
   "source": [
    "#@title Choose a dataset\n",
    "#@markdown Name of the dataset to train on; must be a key of `AVAILABLE_DATASETS`.\n",
    "\n",
    "DATASET_NAME = 'mnist' #@param [\"fashion_mnist\", \"mnist\"]\n",
    "print(DATASET_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cH80JCZiMlTO"
   },
   "outputs": [],
   "source": [
    "# Number of target classes (labels 0, 1, .., 9)\n",
    "NUM_CLASSES = 10\n",
    "# Shape of one model input: 28x28 image with a single channel axis\n",
    "INPUT_SHAPE = (28, 28, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54
    },
    "colab_type": "code",
    "id": "R2FcVG-yNhHa",
    "outputId": "c8842bfa-1a89-402e-bea3-6636d03f6afd"
   },
   "outputs": [],
   "source": [
    "# Fetch the train/test splits of the selected dataset\n",
    "dataset = AVAILABLE_DATASETS[DATASET_NAME]\n",
    "(train_images, train_labels), (test_images, test_labels) = dataset.load_data()\n",
    "\n",
    "# Append a channel axis, (28, 28) -> (28, 28, 1), and cast the pixels to float32\n",
    "train_images = np.expand_dims(train_images, axis=-1).astype('float32')\n",
    "test_images = np.expand_dims(test_images, axis=-1).astype('float32')\n",
    "\n",
    "# One-hot encode the integer labels 0, 1, .., 9, e.g. 7 -> [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]\n",
    "train_labels = tf.keras.utils.to_categorical(train_labels, num_classes=NUM_CLASSES)\n",
    "test_labels = tf.keras.utils.to_categorical(test_labels, num_classes=NUM_CLASSES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TogCHJ2qNxfX"
   },
   "outputs": [],
   "source": [
    "# Build a small CNN classifier with the Keras functional API\n",
    "layers = tf.keras.layers\n",
    "\n",
    "img_input = tf.keras.Input(INPUT_SHAPE)\n",
    "\n",
    "layer_stack = [\n",
    "    # Convolutional feature extractor\n",
    "    layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),\n",
    "    # Explicitly named so it can be referenced by name ('target_layer') later on\n",
    "    layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu', name='target_layer'),\n",
    "    layers.MaxPool2D(pool_size=(2, 2)),\n",
    "    layers.Dropout(0.25),\n",
    "    layers.Flatten(),\n",
    "    # Dense classification head\n",
    "    layers.Dense(128, activation='relu'),\n",
    "    layers.Dropout(0.5),\n",
    "    layers.Dense(NUM_CLASSES, activation='softmax'),\n",
    "]\n",
    "\n",
    "# Chain the layers in order, starting from the input tensor\n",
    "x = img_input\n",
    "for layer in layer_stack:\n",
    "    x = layer(x)\n",
    "\n",
    "model = tf.keras.Model(img_input, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hSKSkyY4Nzjd"
   },
   "outputs": [],
   "source": [
    "# Categorical cross-entropy pairs with the one-hot encoded labels\n",
    "model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='categorical_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZhPmDKGCN5LO"
   },
   "outputs": [],
   "source": [
    "# Select a subset of the test data to examine.\n",
    "# Labels are one-hot encoded, so class \"0\" is [1, 0, 0, .., 0] and argmax recovers the class index.\n",
    "is_class_zero = np.argmax(test_labels, axis=1) == 0\n",
    "\n",
    "# Keep the first 5 matching images; the labels slot of the (images, labels) pair is left as None.\n",
    "validation_class_zero = (test_images[is_class_zero][:5], None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "R3dqe5uaN6bt"
   },
   "outputs": [],
   "source": [
    "# Select a subset of the test data to examine.\n",
    "# Labels are one-hot encoded, so class \"4\" is [0, 0, 0, 0, 1, 0, 0, 0, 0, 0] and argmax recovers the class index.\n",
    "is_class_four = np.argmax(test_labels, axis=1) == 4\n",
    "\n",
    "# Keep the first 5 matching images; the labels slot of the (images, labels) pair is left as None.\n",
    "validation_class_fours = (test_images[is_class_four][:5], None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "x5ZCysCvOBw1"
   },
   "outputs": [],
   "source": [
    "# Build the tf-explain callbacks that are handed to model.fit.\n",
    "# Each class_index must match the class of the validation images it is paired with.\n",
    "explainers = tf_explain.callbacks\n",
    "\n",
    "callbacks = [\n",
    "    explainers.GradCAMCallback(\n",
    "        validation_class_zero, layer_name='target_layer', class_index=0\n",
    "    ),\n",
    "    explainers.GradCAMCallback(\n",
    "        validation_class_fours, layer_name='target_layer', class_index=4\n",
    "    ),\n",
    "    explainers.ActivationsVisualizationCallback(\n",
    "        validation_class_zero, layers_name=['target_layer']\n",
    "    ),\n",
    "    explainers.SmoothGradCallback(\n",
    "        validation_class_zero, class_index=0, num_samples=15, noise=1.\n",
    "    ),\n",
    "    explainers.IntegratedGradientsCallback(\n",
    "        validation_class_zero, class_index=0, n_steps=10\n",
    "    ),\n",
    "    explainers.VanillaGradientsCallback(validation_class_zero, class_index=0),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "w1yXsgeOQR1O"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "      <iframe id=\"tensorboard-frame-4cf06318041e1d97\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
       "      </iframe>\n",
       "      <script>\n",
       "        (function() {\n",
       "          const frame = document.getElementById(\"tensorboard-frame-4cf06318041e1d97\");\n",
       "          const url = new URL(\"/\", window.location);\n",
       "          url.port = 6006;\n",
       "          frame.src = url;\n",
       "        })();\n",
       "      </script>\n",
       "  "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%load_ext tensorboard\n",
    "%tensorboard --logdir logs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 217
    },
    "colab_type": "code",
    "id": "_-4kcky6OOnc",
    "outputId": "1b9e0f26-d08b-466a-93ef-dc834227af53"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples\n",
      "Epoch 1/5\n",
      "60000/60000 [==============================] - 136s 2ms/sample - loss: 0.4756 - accuracy: 0.8967\n",
      "Epoch 2/5\n",
      "60000/60000 [==============================] - 107s 2ms/sample - loss: 0.1403 - accuracy: 0.9590\n",
      "Epoch 3/5\n",
      "60000/60000 [==============================] - 75s 1ms/sample - loss: 0.1082 - accuracy: 0.9691\n",
      "Epoch 4/5\n",
      "19200/60000 [========>.....................] - ETA: 56s - loss: 0.0921 - accuracy: 0.9733"
     ]
    }
   ],
   "source": [
    "# Train for 5 epochs with the tf-explain callbacks attached\n",
    "model.fit(\n",
    "    train_images,\n",
    "    train_labels,\n",
    "    epochs=5,\n",
    "    callbacks=callbacks,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "tf-explain-mnist.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
